import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch.nn.functional as F
#from utils import train
import torchvision.models as models
import time
import matplotlib.pyplot as plt
import cvxpy as cvx
import scipy.io as scio
# Wall-clock start of the whole run; the elapsed time is printed at the end.
time_start = time.time()
# Use the first CUDA GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

import scipy.io as scio   

# Load pre-extracted DeCAF features and labels for the four domains
# (amazon, caltech, dslr, webcam) from .mat files.
data1=scio.loadmat('./data/oc_feature/decaf/amazon_decaf.mat')
data2=scio.loadmat('./data/oc_feature/decaf/caltech_decaf.mat')
data3=scio.loadmat('./data/oc_feature/decaf/dslr_decaf.mat')
data4=scio.loadmat('./data/oc_feature/decaf/webcam_decaf.mat')
# ravel() flattens each label array to 1-D without hard-coding the domain's
# sample count, so a differently sized .mat file no longer raises ValueError.
x1=data1['feas']
y1=data1['labels'].ravel()
x2=data2['feas']
y2=data2['labels'].ravel()
x3=data3['feas']
y3=data3['labels'].ravel()
x4=data4['feas']
y4=data4['labels'].ravel()
print('data ok')

# Shift labels 1..10 down to class indices 0..9, in place. Values outside
# 1..10 are left untouched, matching the original per-class remapping loops.
for _labels in (y1, y2, y3, y4):
    _labels[np.isin(_labels, np.arange(1, 11))] -= 1
#feature---------------
class Net_f(nn.Module):
    """Feature network f: a two-layer MLP (Linear -> ReLU -> Linear).

    The defaults reproduce the original hard-coded sizes (4096 -> 128 -> 10).
    ``out_dim`` must equal the output size of the label network it is paired
    with, because the loss multiplies the two outputs element-wise.

    Args:
        in_dim: Width of the input feature vectors.
        hidden_dim: Width of the hidden ReLU layer.
        out_dim: Width of the output embedding.
    """

    def __init__(self, in_dim=4096, hidden_dim=128, out_dim=10):
        super(Net_f, self).__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Map a batch ``(N, in_dim)`` to embeddings ``(N, out_dim)``."""
        out = F.relu(self.fc1(x))
        return self.fc2(out)


class Net_g(nn.Module):
    """Label network g: a single affine layer ``num_class -> dim``.

    The script feeds it one-hot label vectors, so each output row is then a
    learned per-class embedding (weight column plus bias).

    Args:
        num_class: Width of the input label vectors.
        dim: Width of the output embedding.
    """

    def __init__(self, num_class=10, dim=10):
        super().__init__()
        self.fc = nn.Linear(num_class, dim)

    def forward(self, x):
        """Apply the linear layer to a batch ``(N, num_class)``."""
        return self.fc(x)

def corr(f,g):
    """Return the batch mean of the row-wise inner products ``<f_i, g_i>``.

    ``f`` and ``g`` must have the same shape; the result is a 0-dim tensor.
    """
    per_row = (f * g).sum(dim=1)
    return per_row.mean()
    
def cov_trace(f,g):
    """Return ``trace(S_f @ S_g)`` where ``S_x = x.T @ x / (n - 1)``.

    The inputs are NOT centred here: ``S_x`` is a covariance only when the
    caller passes column-centred data. Each input needs at least two rows,
    otherwise the ``n - 1`` divisor is zero.
    """
    def _scaled_gram(m):
        # Unbiased (n - 1) normalisation of the Gram matrix m^T m.
        return m.t().mm(m) / (m.size(0) - 1.)

    return torch.trace(_scaled_gram(f).mm(_scaled_gram(g)))

def neg_hscore(f,g):
    """Return the negative H-score of ``f`` and ``g``.

    Both inputs are column-centred first. The value is
    ``-mean_i <f0_i, g0_i> + trace(cov(f0) @ cov(g0)) / 2``, using the unbiased
    (n - 1) covariance. Inputs must share a shape and have at least two rows.
    """
    f_c = f - f.mean(0)
    g_c = g - g.mean(0)
    # Named `inner` so it does not shadow the module-level corr() helper.
    inner = (f_c * g_c).sum(1).mean()
    sigma_f = f_c.t().mm(f_c) / (f_c.size(0) - 1.)
    sigma_g = g_c.t().mm(g_c) / (g_c.size(0) - 1.)
    penalty = sigma_f.mm(sigma_g).trace()
    return -inner + penalty / 2.

lr=0.0005
epoch=40
ind=0
model_f = Net_f().to(device)
model_g = Net_g().to(device)
optimizer_fg = torch.optim.Adam(list(model_f.parameters())+list(model_g.parameters()),lr=lr)
losslist=[]
acclist=[0]
# Weights of the two correlation terms in the loss: alpha[0] for the
# reference set (drawn from x2/y2), alpha[1] for the source set (x1/y1).
alpha=[0.6,0.4]


def _first_k_per_class(x, y, k, num_classes=10):
    """Collect the first ``k`` rows of ``x`` for each class ``0..num_classes-1``.

    Returns ``(features, labels)`` in class order. ``labels`` holds one entry
    per row actually selected, so a class with fewer than ``k`` samples cannot
    misalign features and labels. It is int64 because ``Tensor.scatter_``
    requires int64 indices.
    """
    feats, labs = [], []
    for c in range(num_classes):
        picked = x[y == c][:k]
        feats.append(picked)
        labs.append(np.full(len(picked), c))
    return np.vstack(feats), np.concatenate(labs).astype(np.int64)


# Labelled reference set: first 3 samples of each class from domain 2.
refx_train, refy_train = _first_k_per_class(x2, y2, 3)
# Source set: first 20 samples of each class from domain 1.
sourcex_train, sourcey_train = _first_k_per_class(x1, y1, 20)

samples_ref=torch.from_numpy(refx_train)
labels_ref=torch.from_numpy(refy_train)
labels_one_hot_ref = torch.zeros(len(labels_ref), 10).scatter_(1, labels_ref.view(-1,1), 1)
samples_trans=torch.from_numpy(sourcex_train)
labels_trans=torch.from_numpy(sourcey_train)
labels_one_hot_trans= torch.zeros(len(labels_trans), 10).scatter_(1, labels_trans.view(-1,1), 1)

# Loop-invariant device tensors, built once instead of on every epoch.
x_ref = samples_ref.float().to(device)
onehot_ref = labels_one_hot_ref.float().to(device)
x_trans = samples_trans.float().to(device)
onehot_trans = labels_one_hot_trans.float().to(device)

# Evaluation data: every sample of domain 2.
# NOTE(review): this includes the reference samples used for training above.
samples_test=torch.from_numpy(x2)
labels_test = y2
total = len(samples_test)
x_test = samples_test.float().to(device)
labellist = torch.Tensor(np.eye(10))  # one one-hot row per class
labels_eye = labellist.to(device)

# Best snapshot so far; stays None if no epoch clears the 0.2 threshold.
finalacc = None
paraf = None
parag = None

for i in range(epoch):
    model_f.train()
    model_g.train()

    f_ref=model_f(x_ref)
    g_ref=model_g(onehot_ref)
    f0_ref = f_ref - torch.mean(f_ref,0)
    g0_ref = g_ref - torch.mean(g_ref,0)
    # Source outputs are centred with the *reference* means, not their own.
    f_trans=model_f(x_trans)-torch.mean(f_ref,0)
    g_trans=model_g(onehot_trans)- torch.mean(g_ref,0)
    optimizer_fg.zero_grad()

    # Mirrors 2 * neg_hscore: alpha-weighted correlation terms on the
    # reference and source sets, covariance penalty from the reference set.
    loss=(-2)*alpha[0]*corr(f0_ref,g0_ref)
    loss+=(-2)*alpha[1]*corr(f_trans,g_trans)
    # NOTE(review): f0_ref/g0_ref are mean-centred above, so this term is zero
    # up to rounding; confirm whether other means were intended here.
    loss+=2*((torch.sum(f0_ref,0)/f0_ref.size()[0])*(torch.sum(g0_ref,0)/g0_ref.size()[0])).sum()
    loss+=cov_trace(f0_ref,g0_ref)
    losslist.append(loss.item())
    loss.backward()
    optimizer_fg.step()
    ind+=1
    print(ind)

    # ------ accuracy: predict argmax_c <f(x) - mean f(source), g(e_c) - mean_c g(e_c)>
    model_f.eval()
    model_g.eval()
    with torch.no_grad():  # evaluation only; no autograd graph needed
        fc = model_f(x_trans).cpu().numpy()
        f_mean = np.sum(fc,axis=0)/fc.shape[0]
        gc = model_g(labels_eye).cpu().numpy()
        gce = np.sum(gc,axis=0)/gc.shape[0]
        gcp = gc-gce
        fc = model_f(x_test).cpu().numpy()
    fcp=fc-f_mean
    fgp=np.dot(fcp,gcp.T)
    acc = (np.argmax(fgp, axis = 1) == labels_test).sum()
    acc=acc/total
    print(acc)
    if acc > 0.2 and acc > max(acclist):
        # state_dict() returns live references to the parameter tensors; clone
        # them so later optimizer steps do not overwrite the saved best weights.
        paraf={k: v.detach().clone() for k, v in model_f.state_dict().items()}
        parag={k: v.detach().clone() for k, v in model_g.state_dict().items()}
        print('changepara')
        finalacc=acc
    acclist.append(acc)

print(finalacc)
#torch.save(paraf, './mpara/ocf_atoc.pth')
#torch.save(parag, './mpara/ovg_atoc.pth')
time_end=time.time()
print(time_end-time_start)
